
Conversation


@quic-meetkuma quic-meetkuma commented Jun 23, 2025

Disabling gradient synchronization is necessary when using gradient_accumulation_step > 1 with DDP enabled.
Currently, we sync gradients on every loss.backward() call, which runs at every step. With gradient accumulation, the weight update happens only at opt.step(); only on that step should gradients be synced across devices.

The model.no_sync() context manager solves this issue.

Here, we do not use it; instead we set ddp_model.require_backward_grad_sync to True or False depending on which step we are on.
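
For reference, a minimal sketch of the idea in a plain PyTorch training loop, assuming a standard DDP setup (the helper name train_epoch and parameters such as grad_accum_steps and loss_fn are illustrative, not taken from this repo):

```python
# Minimal sketch, not the PR's code: toggle DDP's require_backward_grad_sync
# so the gradient all-reduce only happens on the step where optimizer.step() runs.
from torch.nn.parallel import DistributedDataParallel as DDP


def train_epoch(ddp_model: DDP, optimizer, dataloader, loss_fn, grad_accum_steps: int):
    optimizer.zero_grad()
    for step, (inputs, labels) in enumerate(dataloader):
        # Sync gradients across ranks only on the step where we will call
        # optimizer.step(); on all other steps each rank accumulates locally.
        is_update_step = (step + 1) % grad_accum_steps == 0
        ddp_model.require_backward_grad_sync = is_update_step

        loss = loss_fn(ddp_model(inputs), labels) / grad_accum_steps
        loss.backward()  # all-reduce fires here only when sync is enabled

        if is_update_step:
            optimizer.step()
            optimizer.zero_grad()
```

Setting the flag directly is equivalent to wrapping the accumulation steps in model.no_sync(), which toggles the same attribute internally.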

@quic-meetkuma quic-meetkuma changed the title [QEff. Finetune]: Added support to sync gradients across devices during backward step only. [QEff. Finetune]: Added support to sync gradients across devices during optimizer step only. Jul 9, 2025
@quic-swatia quic-swatia merged commit 3aaa2d8 into quic:main Jul 9, 2025
4 checks passed
quic-amitraj pushed a commit that referenced this pull request Jul 10, 2025
quic-dhirajku pushed a commit to quic-dhirajku/efficient-transformers that referenced this pull request Aug 4, 2025
